import torch
from torch import nn


class RewardDecoder(torch.nn.Module):
    """Map flattened 28x28 inputs to a single sigmoid-activated score.

    Two variants are selected by ``type``:

    * ``'full'``: the context ``x`` and the feedback ``y`` are each encoded
      to 16 features by separate MLPs, concatenated with the action tensor
      ``a`` and mapped to one logit by a linear layer.
    * ``'partial'``: only ``y`` is encoded and mapped to one logit.

    Args:
        type: Either ``'full'`` or ``'partial'``.
        num_action: Width of the action tensor ``a`` (used only by
            ``'full'``, where the last layer expects ``32 + num_action``
            input features).

    Raises:
        ValueError: If ``type`` is not one of the supported variants.
    """

    # Number of elements in one flattened input sample.
    _INPUT_DIM = 28 * 28

    def __init__(self, type, num_action):
        super().__init__()
        # Fail fast: previously an unknown type produced a module without
        # layers that only broke later, inside forward().
        if type not in ('full', 'partial'):
            raise ValueError(
                "type must be 'full' or 'partial', got {!r}".format(type)
            )

        if type == 'full':
            # layer for context
            self.layer1 = self._make_encoder()
            # layer for feedback
            self.layer2 = self._make_encoder()
            # 16 + 16 + num_action
            self.layer3 = nn.Linear(32 + num_action, 1)
        else:
            self.layer1 = self._make_encoder()
            self.layer2 = nn.Linear(16, 1)
        self.type = type

    @staticmethod
    def _make_encoder():
        """Build the 784 -> 64 -> 16 ReLU MLP shared by all encoders."""
        return nn.Sequential(
            nn.Linear(RewardDecoder._INPUT_DIM, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
        )

    def forward(self, y, x=None, a=None):
        """Compute the sigmoid score.

        Inputs are flattened to ``(batch, 784)``, so a single sample
        (784 elements) yields a ``(1, 1)`` output exactly as before, while
        batched inputs yield ``(batch, 1)``.

        Args:
            y: Feedback tensor whose element count is a multiple of 784.
            x: Context tensor (same size constraint); required for
                ``'full'``, ignored for ``'partial'``.
            a: Action tensor; required for ``'full'``. It is concatenated
                along dim 1, so it must be 2-D with the same batch size as
                the flattened ``x``/``y`` and ``num_action`` columns.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.

        Raises:
            ValueError: If ``x`` or ``a`` is missing in ``'full'`` mode, or
                if ``self.type`` was changed to an unsupported value.
        """
        if self.type == 'full':
            if x is None or a is None:
                raise ValueError(
                    "RewardDecoder of type 'full' requires both x and a"
                )
            # reshape (not view) also accepts non-contiguous tensors; -1
            # infers the batch size instead of forcing it to 1.
            x = self.layer1(x.reshape(-1, self._INPUT_DIM))
            y = self.layer2(y.reshape(-1, self._INPUT_DIM))
            z = self.layer3(torch.cat((x, y, a), 1))
        elif self.type == 'partial':
            y = self.layer1(y.reshape(-1, self._INPUT_DIM))
            z = self.layer2(y)
        else:
            raise ValueError(
                "unsupported decoder type {!r}".format(self.type)
            )

        return torch.sigmoid(z)
